CREATE OR REPLACE FUNCTION "pipeline"."dbscan"("source_table" text, "out_table" varchar, "id_col" varchar, "feature_cols" varchar, "eps" float8, "min_samples" int4, "metric" varchar)
  RETURNS "pg_catalog"."text" AS $BODY$
  import numpy as np
  from sklearn.cluster import DBSCAN
  import json
  import time

  # Initialize the response envelope: status/error_msg plus the echoed input
  # parameters and (eventually) the output-table description. The whole dict
  # is serialized to JSON and returned to the SQL caller.
  result = {
    "status": 0,
    "error_msg": "success",
    "result": {
      "input_params": {
        "out_table": out_table,
        "source_table": source_table,
        "id_col": id_col,
        "feature_cols": [],
        "eps": eps,
        "min_samples": min_samples,
        "metric": metric
      },
      "output_params":[]
    }
  }

  # Build the SELECT statement that reads the id and feature columns.
  def get_data_sql(source_table, id_col, feature_cols):
    """Return "select <id_col>, <feature_cols> from <source_table>".

    feature_cols is the caller-supplied comma-separated column list and is
    interpolated verbatim.
    """
    return "select {}, {} from {}".format(id_col, feature_cols, source_table)
  
  # Fetch column metadata for a table from information_schema.
  def getTableMetaInfo(table_name):
    """Return {column_name: data_type} for table_name.

    table_name may be schema-qualified ("schema.table"); only the table part
    is matched against information_schema.columns.
    """
    tmps = table_name.split(".")
    # BUGFIX: was `len(tmps) > 0`, which indexed tmps[1] even for a bare
    # table name (split always yields at least one element) -> IndexError.
    if len(tmps) > 1:
      table_name = tmps[1]
    sql_str = "select column_name, data_type from information_schema.columns where table_name = '%s'" %(table_name)
    rs = plpy.execute(sql_str)
    meta = {}
    for line in rs:
      meta[line['column_name']] = line['data_type']
    return meta
  
  # Create the output table and insert one row per clustered record.
  def saveOutTable(datas, out_table, id_col, meta, features):
    """Persist clustered rows into out_table.

    datas  : list of rows shaped [id, feature..., cluster_label]
    meta   : {column_name: data_type} from getTableMetaInfo
    Returns {"out_table_name": ..., "cols": [...]}.
    """
    text_types = ["text", "character varying", "varchar", "char", "date", "character",
        "timestamp", "time", "timestamp_with_timezone", "longvarchar", "longnvarchar", "nvarchar", "nchar"]
    # BUGFIX: the original computed a quoting template (value_tpl) but never
    # used it, so textual id values were emitted unquoted -> invalid SQL.
    id_is_text = meta.get(id_col) in text_types
    table_name = out_table
    fieldAndTypes = []
    for fea in features:
        fieldAndTypes.append("{} {}".format(fea, meta.get(fea)))
    sql_str = "CREATE TABLE %s (%s %s, %s, cluster_id varchar);" %(table_name, id_col, meta.get(id_col), ",".join(fieldAndTypes))
    plpy.execute(sql_str)
    cols = [id_col]
    cols.extend(features)
    cols.append("cluster_id")
    # BUGFIX: an empty input produced "INSERT ... VALUES" with no tuples,
    # which is a syntax error; create the (empty) table and return.
    if not datas:
      return {"out_table_name": table_name, "cols": cols}
    _values = []
    for item in datas:
      parts = []
      if id_is_text:
        # Quote textual ids and escape embedded single quotes.
        parts.append("'{}'".format(str(item[0]).replace("'", "''")))
      else:
        parts.append(str(item[0]))
      for v in item[1:-1]:
        parts.append(str(v))
      # cluster_id is varchar: quote the label instead of relying on an
      # implicit integer->varchar assignment cast.
      parts.append("'{}'".format(str(item[-1]).replace("'", "''")))
      _values.append("({})".format(",".join(parts)))
    sql_str = "INSERT INTO %s (%s, %s, cluster_id) VALUES" %(table_name, id_col, ",".join(features))
    sql_str += ",".join(_values)
    plpy.execute(sql_str)
    return {"out_table_name": table_name, "cols": cols}

  def throwError(e):
    """Mark the shared result envelope as failed with the exception message."""
    global result
    result.update({"status": 500, "error_msg": str(e)})
  
  # Main driver: read the source rows, run DBSCAN, persist the labelled rows.
  def begin_alg(source_table, out_table, id_col, feature_cols, eps, min_samples, metric):
    """Run DBSCAN over source_table and write results to out_table.

    source_table may also be a "create view ..." statement; the view is then
    created and used as the source. Always returns a JSON string (the result
    envelope), for both success and failure.
    """
    global result
    sourceTable = source_table
    if source_table.startswith("create view") or source_table.startswith("CREATE VIEW"):
        # BUGFIX: parse the view name up front. The original only assigned it
        # after a successful execute, so the failure path ran
        # "DROP VIEW IF EXISTS <entire DDL string>".
        tokens = source_table.split(" ")
        viewName = tokens[2] if len(tokens) > 2 else source_table
        try:
            plpy.execute(source_table)
            sourceTable = viewName
        except Exception as e:
            plpy.execute("DROP VIEW IF EXISTS {}".format(viewName))
            msg = "sq={}, errorMsg={}".format(source_table, str(e))
            result["status"] = 500
            result["error_msg"] = msg
            return json.dumps(result)
    sql_str = get_data_sql(sourceTable, id_col, feature_cols)
    features = feature_cols.split(",")
    result["result"]["input_params"]["feature_cols"] = features
    ids = []
    rv = None
    try:
      rv = plpy.execute(sql_str)
    except Exception as e:
      throwError(e)
      result["test"] = sql_str  # expose the failing query for debugging
      # BUGFIX: was `return result` (a dict) — PL/Python would stringify it to
      # a Python repr, not the JSON the caller expects.
      return json.dumps(result)
    # Collect ids and build the numeric feature matrix.
    datas = []
    for item in rv:
      line = []
      ids.append(item[id_col])
      for feature in features:
        line.append(float(item[feature]))
      datas.append(line)
    datas = np.array(datas, dtype=np.float64)
    dbs = None
    try:
      dbs = DBSCAN(eps = eps, min_samples = min_samples, metric = metric).fit(datas)
    except Exception as e:
      throwError(e)
      result["dbscan"] = "error"
      # BUGFIX: was `return result`; must return JSON text.
      return json.dumps(result)
    labels = dbs.labels_.tolist()
    # DBSCAN labels noise points -1; exclude them from the cluster count.
    n_clusters = len(set(labels)) - (1 if -1 in labels else 0)

    # Assemble output rows: [id, feature..., cluster_label].
    out_datas = []
    for i in range(len(ids)):
      lineData = [ids[i]]
      for j in range(len(features)):
        lineData.append(datas[i][j])
      lineData.append(str(labels[i]))
      out_datas.append(lineData)
    meta = getTableMetaInfo(sourceTable)
    saveRet = None
    try:
      saveRet = saveOutTable(out_datas, out_table, id_col, meta, features)
    except Exception as e:
      throwError(e)
      result["saveError"] = meta  # expose the metadata used, for debugging
      # BUGFIX: was `return result`; must return JSON text.
      return json.dumps(result)
    outTable = {"out_table_name": saveRet["out_table_name"], "output_cols": saveRet["cols"], "n_clusters": n_clusters}
    result["result"]["output_params"].append(outTable)
    return json.dumps(result)
  
  # NOTE(review): assumes PL/Python exposes __name__ == '__main__' for the
  # procedure body — confirm against the target PostgreSQL version.
  if __name__ == '__main__':
    # NOTE(review): this `global result` appears to be load-bearing: the whole
    # $BODY$ is compiled as one function scope, so declaring `result` global
    # here presumably pushes the top-level `result = {...}` assignment into
    # module globals, which is what lets the nested functions' own
    # `global result` statements see it. Do not remove without testing.
    global result
    ret = begin_alg(source_table, out_table, id_col, feature_cols, eps, min_samples, metric)
    return ret
$BODY$
  LANGUAGE plpythonu VOLATILE
  COST 100